/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
 */

package com.huawei.housekeeper.commonutils.exception;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import javax.validation.ConstraintViolationException;

import org.apache.commons.lang3.StringUtils;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.validation.BindException;
import org.springframework.validation.BindingResult;
import org.springframework.validation.FieldError;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseBody;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.huawei.housekeeper.commonutils.enums.ErrorCode;
import com.huawei.housekeeper.commonutils.result.Result;

import io.jsonwebtoken.ExpiredJwtException;
import lombok.extern.log4j.Log4j2;

/**
 * Global exception handler: converts any exception escaping a controller into a {@link Result}.
 *
 * <p>Parameter-validation and request-parsing exceptions registered in {@link #EXCEPTION_FUNC_MAP}
 * are answered with {@link ErrorCode#PARAM_INVALID}. Every other exception is mapped to an error
 * code in {@link #resultError(Exception)}.
 *
 * @author y00464350
 * @since 2022-02-11
 */
@Log4j2
@ControllerAdvice
@Order(Ordered.HIGHEST_PRECEDENCE)
public class GlobalExceptionHandler {

    /**
     * Message extractors for parameter-validation / request-parsing exceptions, keyed by exception
     * class. The map is unmodifiable. It is built from method references, so it does not depend on
     * the textual order of other static fields.
     */
    private static final Map<Class<?>, Function<Exception, String>> EXCEPTION_FUNC_MAP = createExceptionFuncMap();

    /**
     * Global exception handling method.
     *
     * @param exception the exception thrown by a controller
     * @param <T>       payload type of the returned result
     * @return a Result carrying the resolved error code and message
     */
    @ResponseBody
    @ExceptionHandler(value = Exception.class)
    public final <T> Result<T> resultError(Exception exception) {
        Function<Exception, String> function = findMessageFunction(exception.getClass());
        if (function != null) {
            String paramMessage = function.apply(exception);
            // Never answer with an empty message; fall back to the PARAM_INVALID default text.
            String resolved = StringUtils.isEmpty(paramMessage) ? ErrorCode.PARAM_INVALID.getMessage() : paramMessage;
            return Result.createResult(ErrorCode.PARAM_INVALID.getCode(), resolved);
        }
        int code = ErrorCode.UNKNOWN.getCode();
        String message = exception.getMessage();
        if (exception instanceof ErrorCodeException) {
            // Business exception carrying its own error code
            code = ((ErrorCodeException) exception).getErrorCode();
            log.error(exception.getMessage(), exception);
        } else if (exception instanceof ExpiredJwtException) {
            // Expired token
            code = ErrorCode.TOKEN_EXPIRED.getCode();
            message = ErrorCode.TOKEN_EXPIRED.getMessage();
            log.error(exception.getMessage(), exception);
        } else if (exception instanceof JsonProcessingException) {
            // Malformed JSON
            code = ErrorCode.JSON_EXECPTION.getCode();
            log.error(exception.getMessage(), exception);
        } else if (exception instanceof ArgumentException) {
            // Custom argument exception carrying its own code
            code = ((ArgumentException) exception).getCode();
            log.warn(exception.getMessage(), exception);
        } else if (exception instanceof NullPointerException) {
            // Null pointer: hide the raw (often empty) message behind the predefined one
            code = ErrorCode.NULL_POINTTER.getCode();
            message = ErrorCode.NULL_POINTTER.getMessage();
            log.warn(exception.getMessage(), exception);
        } else {
            // Unexpected exception: log the full stack trace
            log.error(exception.getMessage(), exception);
        }
        message = StringUtils.isEmpty(message) ? ErrorCode.UNKNOWN.getMessage() : message;
        return Result.createResult(code, message);
    }

    /**
     * Builds the unmodifiable exception-class to message-extractor registry.
     *
     * @return the registry
     */
    private static Map<Class<?>, Function<Exception, String>> createExceptionFuncMap() {
        Map<Class<?>, Function<Exception, String>> map = new HashMap<>();
        map.put(BindException.class, GlobalExceptionHandler::bindExceptionMessage);
        map.put(MethodArgumentNotValidException.class, GlobalExceptionHandler::methodArgumentNotValidMessage);
        map.put(ConstraintViolationException.class, GlobalExceptionHandler::constraintViolationMessage);
        map.put(HttpMessageNotReadableException.class, GlobalExceptionHandler::httpMessageNotReadableMessage);
        return Collections.unmodifiableMap(map);
    }

    /**
     * Finds the message extractor for an exception class. The lookup tries the class itself first,
     * then each superclass, so subclasses of a registered exception are handled too.
     *
     * @param exceptionClass runtime class of the thrown exception
     * @return the matching extractor, or {@code null} if none is registered
     */
    private static Function<Exception, String> findMessageFunction(Class<?> exceptionClass) {
        for (Class<?> type = exceptionClass; type != null; type = type.getSuperclass()) {
            Function<Exception, String> function = EXCEPTION_FUNC_MAP.get(type);
            if (function != null) {
                return function;
            }
        }
        return null;
    }

    /**
     * Message extractor for {@link BindException}.
     *
     * @param exception exception known to be a BindException
     * @return error description
     */
    private static String bindExceptionMessage(Exception exception) {
        BindingResult bindingResult = ((BindException) exception).getBindingResult();
        return getErrorMessage(exception, bindingResult);
    }

    /**
     * Message extractor for {@link MethodArgumentNotValidException}.
     *
     * @param exception exception known to be a MethodArgumentNotValidException
     * @return error description
     */
    private static String methodArgumentNotValidMessage(Exception exception) {
        BindingResult bindingResult = ((MethodArgumentNotValidException) exception).getBindingResult();
        return getErrorMessage(exception, bindingResult);
    }

    /**
     * Message extractor for {@link ConstraintViolationException}: logs the exception and returns its message.
     *
     * @param exception exception known to be a ConstraintViolationException
     * @return the exception message
     */
    private static String constraintViolationMessage(Exception exception) {
        log.warn(exception.getMessage(), exception);
        return exception.getMessage();
    }

    /**
     * Message extractor for {@link HttpMessageNotReadableException}. Prefers the direct cause's
     * message. It falls back to the exception's own message because Spring may throw this
     * exception without a cause (e.g. when the request body is missing).
     *
     * @param exception exception known to be an HttpMessageNotReadableException
     * @return error description
     */
    private static String httpMessageNotReadableMessage(Exception exception) {
        log.warn(exception.getMessage(), exception);
        Throwable cause = exception.getCause();
        return cause != null ? cause.getMessage() : exception.getMessage();
    }

    /**
     * Builds the error description for a binding/validation failure.
     *
     * <p>Uses the first field error when present. Otherwise it uses the first global (object-level)
     * error. Failing both, it uses the exception message, so a binding result without field errors
     * no longer causes a NullPointerException.
     *
     * @param exception     the validation exception (logged at warn level)
     * @param bindingResult validation result
     * @return error description
     */
    private static String getErrorMessage(Exception exception, BindingResult bindingResult) {
        log.warn(exception.getMessage(), exception);
        FieldError fieldError = bindingResult.getFieldError();
        if (fieldError != null) {
            return "属性:[" + fieldError.getField() + "] 错误提示 -> " + fieldError.getDefaultMessage();
        }
        if (bindingResult.hasGlobalErrors()) {
            return bindingResult.getGlobalError().getDefaultMessage();
        }
        return exception.getMessage();
    }
}
